Skip to content

Conversation

@tolgacangoz
Copy link
Contributor

@tolgacangoz tolgacangoz commented Aug 29, 2025

This PR is fixing #12257.

Comparison with the original repo

When I put with torch.amp.autocast('cuda', dtype=torch.bfloat16): onto the transformer only and converted the initial noise's dtype into torch.float32 from torch.bfloat16 in the original repo, the videos seem almost the same. As far as I can see, the original repo's video has an extra blink.

wan.mp4
diffusers.mp4
Try WanSpeechToVideoPipeline!
!git clone https://github.com/tolgacangoz/diffusers.git
%cd diffusers
#!git switch "integrations/wan2.2-s2v"  # This is constantly changing...
!git switch "wan2.2-s2v"
!pip install pip uv -qU
!uv pip install -e ".[dev]" -q
!uv pip install imageio-ffmpeg ftfy decord ninja packaging kernels -q
# For Flash attention 2:
#!uv pip install flash-attn --no-build-isolation
# For Flash attention 3 in diffusers:
#import os
#os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"


import numpy as np
import torch, os
from diffusers import AutoencoderKLWan, WanSpeechToVideoPipeline
from diffusers.utils import export_to_video, load_image, load_audio, load_video
from transformers import Wav2Vec2ForCTC

model_id = "Wan-AI/Wan2.2-S2V-14B-Diffusers"  # will be official
model_id = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"
audio_encoder = Wav2Vec2ForCTC.from_pretrained(model_id, subfolder="audio_encoder", dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanSpeechToVideoPipeline.from_pretrained(
    model_id, vae=vae, audio_encoder=audio_encoder, torch_dtype=torch.bfloat16,
)#.to("cuda")
pipe.enable_model_cpu_offload()
#pipe.transformer.set_attention_backend("flash")  # FA 2
#pipe.transformer.set_attention_backend("_flash_3_hub")  # FA 3

first_frame = load_image("https://raw.githubusercontent.com/Wan-Video/Wan2.2/refs/heads/main/examples/i2v_input.JPG")
audio, sampling_rate = load_audio("https://github.com/Wan-Video/Wan2.2/raw/refs/heads/main/examples/talk.wav")

import math

def get_size_less_than_area(height,
                            width,
                            target_area=1024 * 704,
                            divisor=64):
    if height * width <= target_area:
        # If the original image area is already less than or equal to the target,
        # no resizing is needed—just padding. Still need to ensure that the padded area doesn't exceed the target.
        max_upper_area = target_area
        min_scale = 0.1
        max_scale = 1.0
    else:
        # Resize to fit within the target area and then pad to multiples of `divisor`
        max_upper_area = target_area  # Maximum allowed total pixel count after padding
        d = divisor - 1
        b = d * (height + width)
        a = height * width
        c = d**2 - max_upper_area

        # Calculate scale boundaries using quadratic equation
        min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (
            2 * a)  # Scale when maximum padding is applied
        max_scale = math.sqrt(max_upper_area /
                                (height * width))  # Scale without any padding

    # We want to choose the largest possible scale such that the final padded area does not exceed max_upper_area
    # Use binary search-like iteration to find this scale
    find_it = False
    for i in range(100):
        scale = max_scale - (max_scale - min_scale) * i / 100
        new_height, new_width = int(height * scale), int(width * scale)

        # Pad to make dimensions divisible by 64
        pad_height = (64 - new_height % 64) % 64
        pad_width = (64 - new_width % 64) % 64
        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_height, padded_width = new_height + pad_height, new_width + pad_width

        if padded_height * padded_width <= max_upper_area:
            find_it = True
            break

    if find_it:
        return padded_height, padded_width
    else:
        # Fallback: calculate target dimensions based on aspect ratio and divisor alignment
        aspect_ratio = width / height
        target_width = int(
            (target_area * aspect_ratio)**0.5 // divisor * divisor)
        target_height = int(
            (target_area / aspect_ratio)**0.5 // divisor * divisor)

        # Ensure the result is not larger than the original resolution
        if target_width >= width or target_height >= height:
            target_width = int(width // divisor * divisor)
            target_height = int(height // divisor * divisor)

        return target_height, target_width

height, width = get_size_less_than_area(first_frame.height, first_frame.width, target_area=480*832)

prompt = "Einstein singing a song."

output = pipe(
    image=first_frame, audio=audio, sampling_rate=sampling_rate,
    prompt=prompt, height=height, width=width, num_frames_per_chunk=80,
).frames[0]
export_to_video(output, "video.mp4", fps=16)

import logging, shutil, subprocess

def merge_video_audio(video_path: str, audio_path: str):
    """
    Merge the video and audio into a new video, with the duration set to the shorter of the two,
    and overwrite the original video file.

    Parameters:
    video_path (str): Path to the original video file
    audio_path (str): Path to the audio file
    """
    # set logging
    logging.basicConfig(level=logging.INFO)

    # check
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video file {video_path} does not exist")
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"audio file {audio_path} does not exist")

    base, ext = os.path.splitext(video_path)
    temp_output = f"{base}_temp{ext}"

    try:
        # create ffmpeg command
        command = [
            'ffmpeg',
            '-y',  # overwrite
            '-i',
            video_path,
            '-i',
            audio_path,
            '-c:v',
            'copy',  # copy video stream
            '-c:a',
            'aac',  # use AAC audio encoder
            '-b:a',
            '192k',  # set audio bitrate (optional)
            '-map',
            '0:v:0',  # select the first video stream
            '-map',
            '1:a:0',  # select the first audio stream
            '-shortest',  # choose the shortest duration
            temp_output
        ]

        # execute the command
        logging.info("Start merging video and audio...")
        result = subprocess.run(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # check result
        if result.returncode != 0:
            error_msg = f"FFmpeg execute failed: {result.stderr}"
            logging.error(error_msg)
            raise RuntimeError(error_msg)

        shutil.move(temp_output, video_path)
        logging.info(f"Merge completed, saved to {video_path}")

    except Exception as e:
        if os.path.exists(temp_output):
            os.remove(temp_output)
        logging.error(f"merge_video_audio failed with error: {e}")

import requests, tempfile
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT

response = requests.get(audio, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
with tempfile.NamedTemporaryFile(delete=False) as talk:
    for chunk in response.iter_content(chunk_size=8192):
        talk.write(chunk)
    talk_file = talk.name

merge_video_audio("video.mp4", talk_file)

@yiyixuxu @sayakpaul @asomoza @dg845 @stevhliu
@WanX-Video-1 @Steven-SWZhang @kelseyee
@SHYuanBest @J4BEZ @okaris @xziayro-ai @teith @luke14free @lopho @arnold408

…date example imports

Add unit tests for WanSpeechToVideoPipeline and WanS2VTransformer3DModel and gguf
The previous audio encoding logic was a placeholder. It is now replaced with a `Wav2Vec2ForCTC` model and processor, including the full implementation for processing audio inputs. This involves resampling and aligning audio features with video frames to ensure proper synchronization.

Additionally, utility functions for loading audio from files or URLs are added, and the `audio_processor` module is refactored to correctly handle audio data types instead of image types.
Introduces support for audio and pose conditioning, replacing the previous image conditioning mechanism. The model now accepts audio embeddings and pose latents as input.

This change also adds two new, mutually exclusive motion processing modules:
- `MotionerTransformers`: A transformer-based module for encoding motion.
- `FramePackMotioner`: A module that packs frames from different temporal buckets for motion representation.

Additionally, an `AudioInjector` module is implemented to fuse audio features into specific transformer blocks using cross-attention.
The `MotionerTransformers` module is removed and its functionality is replaced by a `FramePackMotioner` module and a simplified standard motion processing pipeline.

The codebase is refactored to remove the `einops` dependency, replacing `rearrange` operations with standard PyTorch tensor manipulations for better code consistency.

Additionally, `AdaLayerNorm` is introduced for improved conditioning, and helper functions for Rotary Positional Embeddings (RoPE) are added (probably temporarily) and refactored for clarity and flexibility. The audio injection mechanism is also updated to align with the new model structure.
Removes the calculation of several unused variables and an unnecessary `deepcopy` operation on the latents tensor.

This change also removes the now-unused `deepcopy` import, simplifying the overall logic.
Refactors the `WanS2VTransformer3DModel` for clarity and better handling of various conditioning inputs like audio, pose, and motion.

Key changes:
- Simplifies the `WanS2VTransformerBlock` by removing projection layers and streamlining the forward pass.
- Introduces `after_transformer_block` to cleanly inject audio information after each transformer block, improving code organization.
- Enhances the main `forward` method to better process and combine multiple conditioning signals (image, audio, motion) before the transformer blocks.
- Adds support for a zero-value timestep to differentiate between image and video latents.
- Generalizes temporal embedding logic to support multiple model variations.
Introduces the necessary configurations and state dictionary key mappings to enable the conversion of S2V model checkpoints to the Diffusers format.

This includes:
- A new transformer configuration for the S2V model architecture, including parameters for audio and pose conditioning.
- A comprehensive rename dictionary to map the original S2V layer names to their Diffusers equivalents.
@tolgacangoz
Copy link
Contributor Author

Thanks @J4BEZ, fixed it.

@J4BEZ
Copy link
Contributor

J4BEZ commented Oct 18, 2025

@tolgacangoz Thanks! I am delighted to help☺️

Have a peaceful day!

Added contributor information and enhanced model description.
Added project page link for Wan-S2V model and improved context.

The project page: https://humanaigc.github.io/wan-s2v-webpage/

This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
Copy link
Contributor Author

@tolgacangoz tolgacangoz Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tolgacangoz
Copy link
Contributor Author

This will be my second official pipeline contribution and my fourth overall, yay 🥳

@tin2tin
Copy link

tin2tin commented Nov 7, 2025

Just a word of encouragement. This technology is actually quite good, and I hope it'll be priotized for review soonish. Here's a video I did with it: https://m.youtube.com/watch?v=N7ARyKKwGfc

@zecloud
Copy link

zecloud commented Nov 12, 2025

Hi @tolgacangoz

I appreciate yout hard work, i tried to use your new pipeline but didn't succeed to make it work like i want

Tried to load a lightx2v lora does not succed :

2025-11-11T17:53:45.0020446Z stdout F Error processing message: 'FrozenDict' object has no attribute 'image_dim'
2025-11-11T17:53:45.0020543Z stderr F pipe.load_lora_weights(
2025-11-11T17:53:45.0020630Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 4068, in load_lora_weights
2025-11-11T17:53:45.0020644Z stderr F state_dict = self._maybe_expand_t2v_lora_for_i2v(
2025-11-11T17:53:45.0020655Z stderr F ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-11-11T17:53:45.0020666Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/loaders/lora_pipeline.py", line 3999, in _maybe_expand_t2v_lora_for_i2v
2025-11-11T17:53:45.0020682Z stderr F if transformer.config.image_dim is None:
2025-11-11T17:53:45.0020692Z stderr F ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-11-11T17:53:45.0020706Z stderr F AttributeError: 'FrozenDict' object has no attribute 'image_dim'

Without Lora I tried the pipeline with .to("cuda") or pipe.enable_model_cpu_offload() or enable_group_offload and always the same error :

2025-11-12T10:57:13.8436869Z stderr F pipe = WanSpeechToVideoPipeline.from_pretrained(
2025-11-12T10:57:13.8436878Z stderr F ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8436893Z stderr F File "/opt/venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
2025-11-12T10:57:13.8436901Z stderr F return fn(*args, **kwargs)
2025-11-12T10:57:13.8436909Z stderr F ^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8436917Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/pipelines/pipeline_utils.py", line 1021, in from_pretrained
2025-11-12T10:57:13.8436925Z stderr F loaded_sub_model = load_sub_model(
2025-11-12T10:57:13.8436933Z stderr F ^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8436940Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/pipelines/pipeline_loading_utils.py", line 876, in load_sub_model
2025-11-12T10:57:13.8437033Z stderr F loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
2025-11-12T10:57:13.8437047Z stderr F ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8437058Z stderr F File "/opt/venv/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
2025-11-12T10:57:13.8437077Z stderr F return fn(*args, **kwargs)
2025-11-12T10:57:13.8437084Z stderr F ^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8437092Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 1316, in from_pretrained
2025-11-12T10:57:13.8437099Z stderr F dispatch_model(model, **device_map_kwargs)
2025-11-12T10:57:13.8437107Z stderr F File "/opt/venv/lib/python3.11/site-packages/accelerate/big_modeling.py", line 502, in dispatch_model
2025-11-12T10:57:13.8437115Z stderr F model.to(device)
2025-11-12T10:57:13.8437123Z stderr F File "/opt/venv/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 1424, in to
2025-11-12T10:57:13.8437131Z stderr F return super().to(*args, **kwargs)
2025-11-12T10:57:13.8437138Z stderr F ^^^^^^^^^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8437146Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1371, in to
2025-11-12T10:57:13.8437154Z stderr F return self._apply(convert)
2025-11-12T10:57:13.8437166Z stderr F ^^^^^^^^^^^^^^^^^^^^
2025-11-12T10:57:13.8437372Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
2025-11-12T10:57:13.8437594Z stderr F module._apply(fn)
2025-11-12T10:57:13.8437711Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
2025-11-12T10:57:13.8438220Z stderr F module._apply(fn)
2025-11-12T10:57:13.8438257Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 930, in _apply
2025-11-12T10:57:13.8438266Z stderr F module._apply(fn)
2025-11-12T10:57:13.8438273Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 957, in _apply
2025-11-12T10:57:13.8438280Z stderr F param_applied = fn(param)
2025-11-12T10:57:13.8438288Z stderr F ^^^^^^^^^
2025-11-12T10:57:13.8438295Z stderr F File "/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1364, in convert
2025-11-12T10:57:13.8438303Z stderr F raise NotImplementedError(
2025-11-12T10:57:13.8438311Z stderr F NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

It's probably not related but when i tested with .to("cuda") i had sageattention activated with pipe.transformer.set_attention_backend("sage")

@tolgacangoz
Copy link
Contributor Author

tolgacangoz commented Nov 13, 2025

Hi @zecloud, thanks for reporting this! I will take a look at it tomorrow (+ conflicts below).

@tolgacangoz tolgacangoz force-pushed the integrations/wan2.2-s2v branch from 454769c to a480ecc Compare November 15, 2025 08:31
@tolgacangoz
Copy link
Contributor Author

tolgacangoz commented Nov 15, 2025

Hi @zecloud. AFAIU, there is no Lightning LoRA specifically for the Wan2.2-S2V model. I guess people try to use Wan2.2's high noise transformer's LoRA for S2V? Which one are you using? Could you share reproducible codes?

@zecloud
Copy link

zecloud commented Nov 17, 2025

Hi @tolgacangoz
It's only the high noise Lora I saw that on reddit and wanted to test with your pipeline https://civitai.com/models/1909425/wan-22-14b-s2v-ultimate-suite-gguf-and-lightning-speed-with-extended-video-generation?modelVersionId=2161199

My test code didn't use any quantized version this was your demo loading code with this code to load the lora.
lightning_hn = hf_hub_download(repo_id="lightx2v/Wan2.2-Distill-Loras" , filename="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",local_dir = "/pretrained_models")
pipe.load_lora_weights(
lightning_hn,
adapter_name="light"
)
pipe.set_adapters(["light"], adapter_weights=[1.0])
pipe.fuse_lora(adapter_names=["light"], lora_scale=3., components=["transformer"])

I won't able to test it again soon but i let you know if i can.

@tolgacangoz
Copy link
Contributor Author

Without Lora I tried the pipeline with .to("cuda") or pipe.enable_model_cpu_offload() or enable_group_offload and always the same error :

Are you sure that you are using the wan2.2-s2v branch as I emphasized in the first comment?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants